{
   "cells": [
      {
         "cell_type": "markdown",
         "metadata": {},
         "source": [
            "### Zero-shot cell type annotation\n",
            "Given the gene expression profiles of the cells, as well as textual descriptions of alternative cell types, LangCell can automatically perform cell type annotation."
         ]
      },
      {
         "cell_type": "code",
         "execution_count": 1,
         "metadata": {},
         "outputs": [
            {
               "name": "stderr",
               "output_type": "stream",
               "text": [
                  "/AIRvePFS/dair/conda_envs/biomed/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
                  "  from .autonotebook import tqdm as notebook_tqdm\n",
                  "/AIRvePFS/dair/conda_envs/biomed/lib/python3.9/site-packages/transformers/deepspeed.py:23: FutureWarning: transformers.deepspeed module is deprecated and will be removed in a future version. Please import deepspeed modules directly from transformers.integrations\n",
                  "  warnings.warn(\n"
               ]
            }
         ],
         "source": [
            "import os\n",
            "import sys\n",
            "parent = os.path.dirname(os.path.abspath(''))\n",
            "sys.path.append(parent)\n",
            "os.chdir(parent)\n",
            "from open_biomed.core.pipeline import InferencePipeline\n",
            "from open_biomed.data import Cell, Text\n",
            "from datasets import load_from_disk\n",
            "import json\n",
            "from open_biomed.data import Cell, Text\n",
            "from sklearn.metrics import classification_report"
         ]
      },
      {
         "cell_type": "code",
         "execution_count": null,
         "metadata": {},
         "outputs": [
            {
               "name": "stderr",
               "output_type": "stream",
               "text": [
                  "05/23/2025 18:35:11 - INFO - root - The config of this process is:\n",
                  "{\n",
                  "    \"model\": {\n",
                  "        \"name\": \"langcell\",\n",
                  "        \"cell_model\": \"./checkpoints/langcell/cell_bert\",\n",
                  "        \"cell_proj\": \"./checkpoints/langcell/cell_proj.bin\",\n",
                  "        \"text_tokenizer\": \"./checkpoints/langcell/pubmedbert-base\",\n",
                  "        \"text_model\": \"./checkpoints/langcell/text_bert\",\n",
                  "        \"text_proj\": \"./checkpoints/langcell/text_proj.bin\",\n",
                  "        \"ctm_head\": \"./checkpoints/langcell/ctm_head.bin\"\n",
                  "    },\n",
                  "    \"task\": \"cell_annotation\",\n",
                  "    \"model_ckpt\": \"\",\n",
                  "    \"device\": \"cuda:2\",\n",
                  "    \"logging_level\": \"info\"\n",
                  "}\n",
                  "Some weights of BertModel were not initialized from the model checkpoint at ./checkpoints/langcell/cell_bert and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']\n",
                  "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
               ]
            }
         ],
         "source": [
            "# Load the model\n",
            "cfg_path = \"./configs/model/langcell.yaml\"\n",
            "pipeline = InferencePipeline(model='langcell', task='cell_annotation', device='cuda:2')"
         ]
      },
      {
         "cell_type": "code",
         "execution_count": 7,
         "metadata": {},
         "outputs": [
            {
               "name": "stdout",
               "output_type": "stream",
               "text": [
                  "B cells ---- cell type: b cell. a lymphocyte of b lineage that is capable of b cell mediated immunity.;  \n",
                  "\n",
                  "CD8 T cells ---- cell type: cd8-positive, alpha-beta t cell. a t cell expressing an alpha-beta t cell receptor and the cd8 coreceptor.;  \n",
                  "\n",
                  "CD14+ Monocytes ---- cell type: cd14-positive monocyte. a monocyte that expresses cd14 and is negative for the lineage markers cd3, cd19, and cd20.;  \n",
                  "\n",
                  "Dendritic Cells ---- cell type: dendritic cell. a cell of hematopoietic origin, typically resident in particular tissues, specialized in the uptake, processing, and transport of antigens to lymph nodes for the purpose of stimulating an immune response via t cell activation. these cells are lineage negative (cd3-negative, cd19-negative, cd34-negative, and cd56-negative).;  \n",
                  "\n",
                  "NK cells ---- cell type: natural killer cell. a lymphocyte that can spontaneously kill a variety of target cells without prior antigenic activation via germline encoded activation receptors and also regulate immune responses via cytokine release and direct contact with other cells.;  \n",
                  "\n",
                  "Megakaryocytes ---- cell type: megakaryocyte. a large hematopoietic cell (50 to 100 micron) with a lobated nucleus. once mature, this cell undergoes multiple rounds of endomitosis and cytoplasmic restructuring to allow platelet formation and release.;  \n",
                  "\n",
                  "FCGR3A+ Monocytes ---- cell type: cd14-low, cd16-positive monocyte. a patrolling monocyte that is cd14-low and cd16-positive.;  \n",
                  "\n",
                  "CD4 T cells ---- cell type: cd4-positive, alpha-beta t cell. a mature alpha-beta t cell that expresses an alpha-beta t cell receptor and the cd4 coreceptor.;  \n",
                  "\n"
               ]
            }
         ],
         "source": [
            "# Load the dataset\n",
            "# Download data: https://drive.google.com/drive/folders/1cuhVG9v0YoAnjW-t_WMpQQguajumCBTp\n",
            "dataset = load_from_disk('/path/to/pbmc10k.dataset')\n",
            "type2text = json.load(open('/path/to/type2text.json'))\n",
            "for type in type2text:\n",
            "    print(type, '----', type2text[type], '\\n')"
         ]
      },
      {
         "cell_type": "code",
         "execution_count": 8,
         "metadata": {},
         "outputs": [],
         "source": [
            "# random sample\n",
            "dataset = dataset.shuffle(seed=42).select(range(2000))"
         ]
      },
      {
         "cell_type": "code",
         "execution_count": null,
         "metadata": {},
         "outputs": [],
         "source": [
            "# Organize data into specific formats as model inputs\n",
            "texts = []\n",
            "type2label = {}\n",
            "labels = []\n",
            "for type in type2text:\n",
            "    texts.append(Text.from_str(type2text[type]))\n",
            "    type2label[type] = len(texts) - 1\n",
            "input = {'cell': [], 'class_texts': [], 'label': []}\n",
            "for data in dataset:\n",
            "    input['cell'].append(Cell.from_sequence(data['input_ids']))\n",
            "    input['class_texts'].append(texts)\n",
            "    input['label'].append(type2label[data['str_labels']])\n",
            "    labels.append(type2label[data['str_labels']])"
         ]
      },
      {
         "cell_type": "code",
         "execution_count": null,
         "metadata": {},
         "outputs": [
            {
               "name": "stderr",
               "output_type": "stream",
               "text": [
                  "Inference Steps:   0%|          | 0/2000 [00:00<?, ?it/s]/AIRvePFS/dair/luoyz-data/projects/OpenBioMed/OpenBioMed_arch/open_biomed/models/cell/langcell/langcell_utils.py:1000: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
                  "  batch = {'cell': torch.tensor(batch['input_ids'], dtype=torch.int64),\n",
                  "/AIRvePFS/dair/luoyz-data/projects/OpenBioMed/OpenBioMed_arch/open_biomed/models/cell/langcell/langcell_utils.py:1001: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
                  "  'attention_mask': torch.tensor(batch['attention_mask'], dtype=torch.int64),\n",
                  "Inference Steps: 100%|██████████| 2000/2000 [03:56<00:00,  8.44it/s]\n"
               ]
            }
         ],
         "source": [
            "# Predict the cell type of each cell using the model\n",
            "preds, _ = pipeline.run(batch_size=1, **input)\n",
            "preds = [p.item() for p in preds]"
         ]
      },
      {
         "cell_type": "code",
         "execution_count": null,
         "metadata": {},
         "outputs": [
            {
               "name": "stdout",
               "output_type": "stream",
               "text": [
                  "                   precision    recall  f1-score   support\n",
                  "\n",
                  "          B cells       1.00      0.98      0.99       279\n",
                  "      CD8 T cells       0.59      0.95      0.73       260\n",
                  "  CD14+ Monocytes       0.96      0.99      0.98       387\n",
                  "  Dendritic Cells       1.00      0.82      0.90        67\n",
                  "         NK cells       0.82      0.98      0.90        57\n",
                  "   Megakaryocytes       0.82      0.90      0.86        20\n",
                  "FCGR3A+ Monocytes       1.00      0.82      0.90        66\n",
                  "      CD4 T cells       0.98      0.81      0.89       864\n",
                  "\n",
                  "         accuracy                           0.89      2000\n",
                  "        macro avg       0.90      0.91      0.89      2000\n",
                  "     weighted avg       0.93      0.89      0.90      2000\n",
                  "\n"
               ]
            }
         ],
         "source": [
            "# Analyze the results\n",
            "print(classification_report(labels, preds, labels=range(len(type2text)), target_names=type2text.keys()))"
         ]
      }
   ],
   "metadata": {
      "kernelspec": {
         "display_name": "biomed",
         "language": "python",
         "name": "python3"
      },
      "language_info": {
         "codemirror_mode": {
            "name": "ipython",
            "version": 3
         },
         "file_extension": ".py",
         "mimetype": "text/x-python",
         "name": "python",
         "nbconvert_exporter": "python",
         "pygments_lexer": "ipython3",
         "version": "3.9.7"
      }
   },
   "nbformat": 4,
   "nbformat_minor": 2
}